Используются три типа задач (Romo, DM и CtxDM), каждая в двух вариантах.
Сеть состоит из LIF AdEx-нейронов.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cgtasknet.net.lifadex import SNNlifadex
from cgtasknet.tasks.reduce import (
CtxDMTaskParameters,
DMTaskParameters,
DMTaskRandomModParameters,
MultyReduceTasks,
RomoTaskParameters,
RomoTaskRandomModParameters,
)
from norse.torch.functional.lif_adex import LIFAdExParameters
from tqdm import tqdm
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"{device=}")
device=device(type='cuda', index=0)
import os
def _to_numpy(array_like):
    """Return ``array_like`` as a NumPy array (torch tensors are detached and moved to CPU)."""
    if isinstance(array_like, torch.Tensor):
        return array_like.detach().cpu().numpy()
    return np.asarray(array_like)


def plot_results(
    inputs, target_outputs, outputs, *, max_plots=20, task_names=None, run_name=None
):
    """Plot inputs, task code, target and network output for several batch samples.

    One figure with four panels is drawn per sample, saved as a PDF under
    ``figures/`` and shown.

    :param inputs: array or tensor indexed as ``[time, batch, channel]``;
        channels 0-2 are plotted as inputs, the remaining ones as the task code.
    :param target_outputs: array or tensor indexed as ``[time, batch, channel]``.
    :param outputs: network output with the same layout as ``target_outputs``.
    :param max_plots: maximum number of batch samples to plot (default 20).
    :param task_names: tick labels for the task-code panel; defaults to the
        module-level ``tasks`` list.
    :param run_name: tag inserted into the PDF file names; defaults to the
        module-level ``name``.
    """
    # Defaults resolve module globals lazily, so explicit arguments never need them.
    task_names = tasks if task_names is None else task_names
    run_name = name if run_name is None else run_name
    # Convert every argument (the original only converted when inputs were
    # tensors, leaving t_outputs unbound otherwise).
    inputs = _to_numpy(inputs)
    t_outputs = _to_numpy(target_outputs)
    outputs = _to_numpy(outputs)  # convert once, not once per channel and sample
    os.makedirs("figures", exist_ok=True)

    for sample in range(min(inputs.shape[1], max_plots)):
        fig = plt.figure(figsize=(15, 3))

        ax_in = fig.add_subplot(141)
        ax_in.set_title("Inputs")
        ax_in.set_xlabel("$time, ms$")
        ax_in.set_ylabel("$Magnitude$")
        for i in range(3):
            ax_in.plot(inputs[:, sample, i], label=rf"$in_{i + 1}$")
        ax_in.legend()

        ax_ctx = fig.add_subplot(142)
        ax_ctx.set_title("Task code (context)")
        # NOTE(review): tick labels are the task names sorted alphabetically;
        # this assumes the context channels follow that order — confirm.
        ax_ctx.set_xticks(np.arange(1, len(task_names) + 1))
        ax_ctx.set_xticklabels(sorted(task_names), rotation=90)
        ax_ctx.set_yticks([])
        # One vertical bar per context channel; bar height is its value at t = 0.
        for i in range(3, inputs.shape[-1]):
            ax_ctx.plot([i - 2] * 2, [0, inputs[0, sample, i]])

        ax_target = fig.add_subplot(143)
        ax_target.set_title("Target output")
        ax_target.set_xlabel("$time, ms$")
        for i in range(t_outputs.shape[-1]):
            ax_target.plot(t_outputs[:, sample, i], label=rf"$out_{i + 1}$")
        ax_target.legend()

        ax_out = fig.add_subplot(144)
        ax_out.set_title("Real output")
        ax_out.set_xlabel("$time, ms$")
        for i in range(outputs.shape[-1]):
            ax_out.plot(outputs[:, sample, i], label=rf"$out_{i + 1}$")
        ax_out.legend()

        fig.tight_layout()
        fig.savefig(
            os.path.join("figures", f"network_outputs_{run_name}_batch_{sample}.pdf")
        )
        plt.show()
        plt.close(fig)
# Run sizes.
batch_size = 100
number_of_epochs = 2000
number_of_tasks = 1

# Timing parameters for each task family.
romo_parameters = RomoTaskRandomModParameters(
    romo=RomoTaskParameters(
        trial_time=0.1,
        positive_shift_trial_time=0.2,
        delay=0.1,
        positive_shift_delay_time=1.4,
    ),
)
dm_parameters = DMTaskRandomModParameters(
    dm=DMTaskParameters(trial_time=0.1, positive_shift_trial_time=0.8)
)
# The context-dependent task reuses the DM timing.
ctx_parameters = CtxDMTaskParameters(dm=dm_parameters.dm)
# Upper bound of the noise std drawn for training augmentation.
sigma = 0.5

tasks = ["RomoTask1", "RomoTask2", "DMTask1", "DMTask2", "CtxDMTask1", "CtxDMTask2"]
# Both variants of a task family share one parameter object.
task_dict = dict(
    zip(tasks, (romo_parameters,) * 2 + (dm_parameters,) * 2 + (ctx_parameters,) * 2)
)
Task = MultyReduceTasks(
    tasks=task_dict, batch_size=batch_size, delay_between=0, enable_fixation_delay=True
)

print("Task parameters:")
for task_key, task_params in task_dict.items():
    print(f"{task_key}:\n{task_params}\n")
n_features, n_actions = Task.feature_and_act_size
print(f"inputs/outputs: {n_features}/{n_actions}")
Task parameters: RomoTask1: RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2) RomoTask2: RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2) DMTask1: DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2) DMTask2: DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2) CtxDMTask1: CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None)) CtxDMTask2: CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None)) inputs/outputs: 9/3
# Visual sanity check of one generated dataset: three panels per sample.
inputs, t_outputs = Task.dataset(n_trials=1)
n_in_channels = inputs.shape[-1]
# (subplot position, title, y-label or None, data, channels, legend prefix)
preview_panels = (
    (131, "Inputs", "$Magnitude$", inputs, range(3), "in"),
    (132, "Task code (context)", None, inputs, range(3, n_in_channels), "in"),
    (133, "Target output", None, t_outputs, range(t_outputs.shape[-1]), "out"),
)
for sample in range(min(batch_size, 10)):
    fig = plt.figure(figsize=(15, 3))
    for position, title, y_label, data, channels, prefix in preview_panels:
        fig.add_subplot(position)
        plt.title(title)
        plt.xlabel("$time, ms$")
        if y_label is not None:
            plt.ylabel(y_label)
        for channel in channels:
            plt.plot(data[:, sample, channel], label=rf"${prefix}_{channel + 1}$")
        plt.legend()
        plt.tight_layout()
    plt.show()
    plt.close()
del inputs
del t_outputs
# Network input/output widths come from the task generator.
feature_size, output_size = Task.feature_and_act_size
hidden_size = 450
# LIF AdEx neuron parameters (norse).
neuron_parameters = LIFAdExParameters(
    v_th=torch.as_tensor(0.65),
    # Per-neuron tau_ada_inv drawn uniformly from [0.5, 6), created on `device`.
    tau_ada_inv=0.5 + (6 - 0.5) * torch.rand(hidden_size).to(device),
    # NOTE(review): presumably the surrogate-gradient sharpness used with
    # `method` — confirm against the norse documentation.
    alpha=100,
    # NOTE(review): presumably selects the SuperSpike surrogate gradient — confirm.
    method="super",
    # rho_reset = torch.as_tensor(5)
)
# Spiking network mapping feature_size inputs -> hidden_size neurons -> output_size outputs.
model = SNNlifadex(
    feature_size,
    hidden_size,
    output_size,
    neuron_parameters=neuron_parameters,
    tau_filter_inv=500,
).to(device)
learning_rate = 1e-2
class RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(MSE(yhat, y) + eps)``.

    ``eps`` keeps the gradient finite when predictions match the targets
    exactly: the derivative of ``sqrt`` is infinite at 0, which would turn
    the zero MSE gradient into NaN.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat, y):
        """Return the RMSE between ``yhat`` and ``y`` as a scalar tensor."""
        return torch.sqrt(self.mse(yhat, y) + self.eps)
# Plain mean-squared error is the training criterion (RMSELoss is not used here).
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
Если памяти не хватает, данные нужно генерировать заново на каждой эпохе прямо в основном цикле обучения; предварительная генерация всего набора ниже по умолчанию отключена.
# Optional pre-generation of the whole training set (disabled by default;
# the training loop generates a fresh batch every epoch instead).
precompute_dataset = False
if precompute_dataset:
    list_inputs = []
    list_t_outputs = []
    for _ in tqdm(range(number_of_epochs)):
        temp_input, temp_t_output = Task.dataset()
        # Same augmentation as the training loop: noise with a std drawn from
        # [0, sigma), applied to the first three (non-context) channels only.
        temp_input[:, :, :3] += np.random.normal(
            0, np.random.uniform(0, sigma), size=temp_input[:, :, :3].shape
        )
        # astype returns a new array; keep the result, otherwise the cast
        # (meant to reduce memory) is a no-op.
        list_inputs.append(temp_input.astype(np.float16))
        list_t_outputs.append(temp_t_output.astype(np.float16))
from cgtasknet.instruments.instrument_accuracy_network import correct_answer
from cgtasknet.net.states import LIFAdExRefracInitState
name = f"Train_dm_and_romo_task_reduce_lif_adex_without_refrac_random_delay_long_a_alpha_{neuron_parameters.alpha}_N_{hidden_size}"
# NOTE(review): init_state is constructed but never passed to the model below
# (the model is called with inputs only). Kept so the name stays available.
init_state = LIFAdExRefracInitState(batch_size, hidden_size, device=device)

log_every = 10  # epochs between loss logs, checkpoints and accuracy checks
eval_batches = 10  # fresh batches per accuracy check
eval_noise_std = 0.01  # input-noise std used during accuracy checks


def _evaluate_accuracy(n_batches, noise_std):
    """Return the percentage of correct answers over ``n_batches`` fresh batches.

    Gaussian noise with std ``noise_std`` is added to the first three
    (non-context) input channels, matching the post-training tests. No
    autograd graph is built.
    """
    correct = 0
    with torch.no_grad():
        for _ in range(n_batches):
            eval_inputs, eval_targets = Task.dataset(1, delay_between=0)
            eval_inputs[:, :, :3] += np.random.normal(
                0, noise_std, size=eval_inputs[:, :, :3].shape
            )
            eval_inputs = torch.from_numpy(eval_inputs).type(torch.float).to(device)
            eval_targets = torch.from_numpy(eval_targets).type(torch.float).to(device)
            eval_outputs = model(eval_inputs)[0]
            answers = correct_answer(
                eval_outputs[:, :, 1:], eval_targets[:, :, 1:], eval_targets[:, :, 0]
            )
            correct += torch.sum(answers).item()
            # Release this batch before the next one is generated.
            del eval_inputs, eval_targets, eval_outputs, answers
            torch.cuda.empty_cache()
    # Assumes correct_answer scores each of the batch_size trials once.
    return correct / batch_size / n_batches * 100


running_loss = 0.0
for i in tqdm(range(number_of_epochs)):
    inputs, target_outputs = Task.dataset()
    # Augmentation: noise std drawn per epoch from [0, sigma), applied to the
    # first three (non-context) input channels only.
    inputs[:, :, :3] += np.random.normal(
        0, np.random.uniform(0, sigma), size=inputs[:, :, :3].shape
    )
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)

    # forward + backward + optimize
    optimizer.zero_grad()
    outputs, _ = model(inputs)
    loss = criterion(outputs, target_outputs)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    if i % log_every == log_every - 1:
        with open("log_multy.txt", "a") as f:
            f.write(
                "epoch: {:d} loss: {:0.5f}\n".format(i + 1, running_loss / log_every)
            )
        running_loss = 0.0
        # Checkpoint on logging epochs instead of rewriting the file every epoch.
        torch.save(model.state_dict(), name)
        # Free the training batch before evaluation to lower peak memory.
        del inputs, target_outputs, outputs, loss
        torch.cuda.empty_cache()
        accuracy = _evaluate_accuracy(eval_batches, eval_noise_std)
        with open("accuracy_multy.txt", "a") as f:
            # `i` is 0-based here, whereas the loss log uses i + 1.
            f.write(f"epoch = {i}; correct/all = {accuracy}\n")
# Always persist the final weights, even if the last epoch was not a logging one.
torch.save(model.state_dict(), name)
print("Finished Training")
100%|██████████| 2000/2000 [3:44:16<00:00, 6.73s/it]
Finished Training
np.random.normal(0, 0.01, size=inputs[:, :, :3].shape)
# Accuracy over fresh batches with Gaussian noise (std 0.01) on the first
# three (non-context) input channels; the last batch is plotted afterwards.
noise_std = 0.01
n_eval_batches = 100
result = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in tqdm(range(n_eval_batches)):
        # Release the previous batch before generating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, noise_std, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Percentage correct, assuming correct_answer scores each of the batch_size trials once.
accuracy = result / batch_size / n_eval_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:39<00:00, 2.20s/it]
93.41
np.random.normal(0, 0.05, size=inputs[:, :, :3].shape)
# Accuracy over fresh batches with Gaussian noise (std 0.05) on the first
# three (non-context) input channels; the last batch is plotted afterwards.
noise_std = 0.05
n_eval_batches = 100
result = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in tqdm(range(n_eval_batches)):
        # Release the previous batch before generating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, noise_std, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Percentage correct, assuming correct_answer scores each of the batch_size trials once.
accuracy = result / batch_size / n_eval_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:40<00:00, 2.21s/it]
93.07
np.random.normal(0, 0.1, size=inputs[:, :, :3].shape)
# Accuracy over fresh batches with Gaussian noise (std 0.1) on the first
# three (non-context) input channels; the last batch is plotted afterwards.
noise_std = 0.1
n_eval_batches = 100
result = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in tqdm(range(n_eval_batches)):
        # Release the previous batch before generating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, noise_std, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Percentage correct, assuming correct_answer scores each of the batch_size trials once.
accuracy = result / batch_size / n_eval_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:39<00:00, 2.19s/it]
93.11
np.random.normal(0, 0.5, size=inputs[:, :, :3].shape)
# Accuracy over fresh batches with Gaussian noise (std 0.5) on the first
# three (non-context) input channels; the last batch is plotted afterwards.
noise_std = 0.5
n_eval_batches = 100
result = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in tqdm(range(n_eval_batches)):
        # Release the previous batch before generating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, noise_std, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Percentage correct, assuming correct_answer scores each of the batch_size trials once.
accuracy = result / batch_size / n_eval_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:38<00:00, 2.18s/it]
89.84
# Single-batch check with Gaussian noise (std 0.5) on the first three
# (non-context) input channels; the batch is plotted afterwards.
noise_std = 0.5
n_eval_batches = 1
result = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in tqdm(range(n_eval_batches)):
        # Release the previous batch before generating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, noise_std, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Divide by the number of batches actually run (previously a hard-coded 100,
# which reported 0.85 instead of 85 % for a single batch).
accuracy = result / batch_size / n_eval_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 1/1 [00:02<00:00, 2.62s/it]
0.8500000000000001
# Single-batch check with Gaussian noise (std 0.7) on the first three
# (non-context) input channels; the batch is plotted afterwards.
noise_std = 0.7
n_eval_batches = 1
result = 0
with torch.no_grad():  # inference only: no autograd graph needed
    for _ in tqdm(range(n_eval_batches)):
        # Release the previous batch before generating the next one.
        inputs = target_outputs = outputs = None
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, noise_std, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Divide by the number of batches actually run (previously a hard-coded 100,
# which reported 0.86 instead of 86 % for a single batch).
accuracy = result / batch_size / n_eval_batches * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 1/1 [00:02<00:00, 2.49s/it]
0.86
# Rebind to plain ints so no tensor is referenced under these names any more.
inputs = outputs = 0
# Export the per-neuron tau_ada_inv values; np.save appends the ".npy" suffix.
tau_ada_inv_distrib = neuron_parameters.tau_ada_inv.cpu().numpy()
tau_file_stem = f"tau_ada_inv_alpha={neuron_parameters.alpha}"
np.save(tau_file_stem, tau_ada_inv_distrib)
# Parse the training accuracy log, one "<label> = <epoch>; correct/all = <acc>"
# line per check. The epoch is read from each line rather than assumed to be
# range(9, 2000, 10), because the log is opened in append mode and may hold
# more or fewer entries.
epochs = []
lines = []  # accuracy values in %, in file order
with open("accuracy_multy.txt", "r") as f:
    for line in f:
        if not line.strip():
            continue
        fields = line.split("=")
        epochs.append(int(fields[1].split(";")[0]))
        lines.append(float(fields[2].strip()))
plt.figure(figsize=(8, 5))
plt.plot(epochs, lines, ".", linestyle="--", markersize=5)
plt.ylabel(r"Accuracy%")
plt.xlabel(r"Epochs")
Text(0.5, 0, 'Epochs')
# Accuracy as a function of input-noise std; one "noise=...:accuracy=..." line
# per level is appended to accuracy_vs_noise.txt.
start_sigma = 1.5
stop_sigma = 2
step_sigma = 0.05
sigma_array = np.arange(start_sigma, stop_sigma, step_sigma)
n_sweep_batches = 20  # batches evaluated per noise level
for test_sigma in tqdm(sigma_array):
    result = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for _ in range(n_sweep_batches):
            # Release the previous batch before generating the next one.
            inputs = target_outputs = outputs = None
            torch.cuda.empty_cache()
            inputs, target_outputs = Task.dataset(1, delay_between=0)
            # Noise only on the first three (non-context) input channels.
            inputs[:, :, :3] += np.random.normal(
                0, test_sigma, size=inputs[:, :, :3].shape
            )
            inputs = torch.from_numpy(inputs).type(torch.float).to(device)
            target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
            outputs = model(inputs)[0]
            answers = correct_answer(
                outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
            )
            result += torch.sum(answers).item()
    # Percentage correct, assuming correct_answer scores each of the batch_size trials once.
    accuracy = result / batch_size / n_sweep_batches * 100
    with open("accuracy_vs_noise.txt", "a") as f:
        f.write(f"noise={test_sigma}:accuracy={accuracy}\n")
inputs = target_outputs = outputs = None
torch.cuda.empty_cache()
100%|██████████| 10/10 [07:11<00:00, 43.10s/it]
# Rectangle patches for shading regions; ggplot look for the figures that follow.
from matplotlib import patches
plt.style.use("ggplot")
def parser(line_text: str) -> tuple:
    """Parse a ``name1=value1:name2=value2`` record into two floats.

    Used for lines such as ``noise=0.5:accuracy=93.1``. Only the first two
    ``:``-separated fields are read; surrounding whitespace (e.g. a trailing
    newline) is ignored by ``float``.

    :param line_text: a single record.
    :return: ``(value1, value2)`` as floats.
    :raises IndexError: if fewer than two fields or a field without ``=``.
    :raises ValueError: if a value is not numeric.
    """
    first_field, second_field = line_text.split(":")[:2]
    return float(first_field.split("=")[1]), float(second_field.split("=")[1])
# Relative path: the same file the noise sweep appends to (the previous
# hard-coded absolute Windows path only worked on one machine).
accuracy_vs_noise_path = "accuracy_vs_noise.txt"
records = []
with open(accuracy_vs_noise_path, "r") as f:
    for line in f:
        if line.strip():
            records.append(parser(line))
# Several sweeps may have been appended to the file; sort by noise level so
# the dashed line runs left to right.
records.sort()
x = [noise for noise, _ in records]
y = [acc for _, acc in records]

fig, ax = plt.subplots()
ax.plot(x, y, ".", linestyle="--")
ax.set_ylabel("Accuracy,%")
ax.set_xlabel(r"$\sigma$")
# Shade the noise range used in training (std drawn from [0, sigma)) for
# accuracies between 50 % and 100 %.
ax.add_patch(
    patches.Rectangle(
        (0, 50),
        sigma,
        50,
        edgecolor="grey",
        facecolor="grey",
        alpha=0.5,
        fill=True,
    )
)
plt.show()
plt.close()